package com.tpvlog.dfs.datanode.server;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.*;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.LinkedBlockingQueue;

import static com.tpvlog.dfs.datanode.config.DataNodeConfig.DATA_DIR;
import static com.tpvlog.dfs.datanode.config.DataNodeConfig.NIO_PORT;

/**
 * NIO networking component of the DataNode.
 *
 * <p>This thread runs the selector loop: it accepts client connections and, whenever a
 * connection becomes readable, hands its {@link SelectionKey} to one of a fixed set of
 * {@link Worker} threads. A connection is always routed to the same worker (by hashing its
 * remote address), so the parsing state of one connection is only touched by one worker.
 *
 * <p>Wire format parsed by the workers (multi-byte integers are read with {@link ByteBuffer}'s
 * default big-endian order):
 * <ul>
 *   <li>every request starts with a 4-byte int request type ({@link #SEND_FILE} or
 *       {@link #READ_FILE});</li>
 *   <li>then a 4-byte int filename length followed by that many filename bytes;</li>
 *   <li>{@code SEND_FILE} continues with an 8-byte long file length and the file bytes, and is
 *       answered with the bytes of {@code "SUCCESS"};</li>
 *   <li>{@code READ_FILE} is answered with an 8-byte long file length followed by the file
 *       bytes.</li>
 * </ul>
 * Any field may arrive split across several reads; partially read fields are buffered per
 * connection and completed on later readable events.
 *
 * @author Ressmix
 */
public class DataNodeNIOServer extends Thread {

    public static final Integer SEND_FILE = 1;
    public static final Integer READ_FILE = 2;

    /** Number of worker threads; each worker owns exactly one hand-off queue. */
    private static final int WORKER_COUNT = 3;

    private final NameNodeRpcClient rpcClient;

    /** Selector driven by {@link #run()}; {@code null} if {@link #init()} failed. */
    private Selector selector;

    /** Hand-off queues from the selector thread to the workers, one per worker. */
    private final List<LinkedBlockingQueue<SelectionKey>> queues = new ArrayList<>();

    /*
     * Per-connection parsing state. Every map below is keyed by
     * channel.getRemoteAddress().toString(), which contains both IP and port, so the state is
     * per connection rather than per client host.
     */

    /** Fields of the in-flight request that have already been fully parsed. */
    private final Map<String, InflightRequest> cachedRequestMap = new ConcurrentHashMap<>();
    /** Partially read request type (4 bytes). */
    private final Map<String, ByteBuffer> requestTypeByClient = new ConcurrentHashMap<String, ByteBuffer>();
    /** Partially read filename length (4 bytes). */
    private final Map<String, ByteBuffer> filenameLengthByClient = new ConcurrentHashMap<String, ByteBuffer>();
    /** Partially read filename bytes. */
    private final Map<String, ByteBuffer> filenameByClient = new ConcurrentHashMap<String, ByteBuffer>();
    /** Partially read file length (8 bytes). */
    private final Map<String, ByteBuffer> fileLengthByClient = new ConcurrentHashMap<String, ByteBuffer>();
    /** Partially read file content of an upload. */
    private final Map<String, ByteBuffer> fileByClient = new ConcurrentHashMap<String, ByteBuffer>();

    /**
     * Creates the server, binds {@code NIO_PORT} and starts the worker threads. The selector
     * loop itself only runs once {@link #start()} is called on this thread.
     *
     * @param rpcClient client used to report successfully received files to the NameNode
     */
    public DataNodeNIOServer(NameNodeRpcClient rpcClient) {
        this.rpcClient = rpcClient;
        init();
    }

    /**
     * Selector loop: blocks for events and dispatches every selected key to
     * {@link #handleEvent(SelectionKey)}. Errors are logged and the loop keeps running.
     */
    @Override
    public void run() {
        if (selector == null) {
            // init() failed and already logged why; without a selector there is nothing to run
            System.err.println("DataNodeNIOServer was not initialized; selector loop not started");
            return;
        }
        while (true) {
            try {
                // blocks until at least one registered channel is ready
                selector.select();
                Iterator<SelectionKey> keysIterator = selector.selectedKeys().iterator();
                while (keysIterator.hasNext()) {
                    SelectionKey key = keysIterator.next();
                    keysIterator.remove();
                    // a worker may have closed the channel (cancelling the key) in the meantime
                    if (key.isValid()) {
                        handleEvent(key);
                    }
                }
            } catch (Throwable t) {
                t.printStackTrace();
            }
        }
    }

    /*-------------------------------------------------PRIVATE METHOD----------------------------------------------*/

    /**
     * Opens the selector, binds a non-blocking server socket on {@code NIO_PORT} (backlog 100),
     * creates one queue per worker and starts the {@code WORKER_COUNT} worker threads.
     * On failure the error is logged, partially opened resources are closed and
     * {@link #selector} is left {@code null} so that {@link #run()} exits instead of spinning.
     */
    private void init() {
        ServerSocketChannel serverChannel = null;
        try {
            selector = Selector.open();
            serverChannel = ServerSocketChannel.open();
            serverChannel.configureBlocking(false);
            serverChannel.socket().bind(new InetSocketAddress(NIO_PORT), 100);
            serverChannel.register(selector, SelectionKey.OP_ACCEPT);

            // create every queue before any worker starts
            for (int i = 0; i < WORKER_COUNT; i++) {
                queues.add(new LinkedBlockingQueue<SelectionKey>());
            }
            // one worker per queue
            for (int i = 0; i < WORKER_COUNT; i++) {
                new Worker(queues.get(i)).start();
            }

            System.out.println("NIOServer已经启动，开始监听端口：" + NIO_PORT);
        } catch (IOException e) {
            e.printStackTrace();
            closeQuietly(serverChannel);
            closeQuietly(selector);
            selector = null;
        }
    }

    /**
     * Handles one selected key on the selector thread.
     *
     * <p>Accept events register the new connection for {@code OP_READ}. Read events are handed to
     * the worker chosen by hashing the connection's remote address, so one connection is always
     * processed by the same worker. On any error the involved channel is closed.
     *
     * @param key a valid selected key
     * @throws Exception if closing the failed channel itself fails
     */
    private void handleEvent(SelectionKey key) throws Exception {
        SocketChannel channel = null;

        try {
            // 1. accept a new connection
            if (key.isAcceptable()) {
                ServerSocketChannel serverSocketChannel = (ServerSocketChannel) key.channel();
                channel = serverSocketChannel.accept();
                if (channel != null) {
                    channel.configureBlocking(false);
                    channel.register(selector, SelectionKey.OP_READ);
                }
            }
            // 2. hand a readable connection to its worker
            else if (key.isReadable()) {
                channel = (SocketChannel) key.channel();
                String remoteAddr = channel.getRemoteAddress().toString();
                // floorMod keeps the index non-negative even when hashCode() is negative
                int queueIndex = Math.floorMod(remoteAddr.hashCode(), queues.size());
                // Stop selecting this key until the worker has consumed the available bytes.
                // The selector is level-triggered: without this, every select() would report
                // the key again and it would be queued (and processed) many times.
                // The worker re-arms OP_READ when it needs more data.
                key.interestOps(key.interestOps() & ~SelectionKey.OP_READ);
                queues.get(queueIndex).put(key);
            }
        } catch (Throwable t) {
            t.printStackTrace();
            if (channel != null) {
                channel.close();
            }
        }
    }

    /** Closes {@code resource}, ignoring {@code null} and logging any failure. */
    private static void closeQuietly(AutoCloseable resource) {
        if (resource == null) {
            return;
        }
        try {
            resource.close();
        } catch (Exception e) {
            e.printStackTrace();
        }
    }

    /**
     * Worker thread: takes readable keys from its queue and parses/serves the request on the
     * corresponding connection. Incomplete requests re-arm {@code OP_READ}; completed requests
     * leave it cleared. On any error the connection's cached state is dropped and the channel
     * is closed.
     */
    class Worker extends Thread {
        private final LinkedBlockingQueue<SelectionKey> queue;

        public Worker(LinkedBlockingQueue<SelectionKey> queue) {
            this.queue = queue;
        }

        @Override
        public void run() {
            while (true) {
                SelectionKey key;
                try {
                    key = queue.take();
                } catch (InterruptedException e) {
                    // restore the flag and stop this worker
                    Thread.currentThread().interrupt();
                    return;
                }

                SocketChannel channel = null;
                String client = null;
                try {
                    channel = (SocketChannel) key.channel();
                    if (!channel.isOpen()) {
                        continue;
                    }
                    client = channel.getRemoteAddress().toString();
                    boolean finished = handleRequest(channel);
                    if (!finished) {
                        // more bytes are needed: let the selector report this key again and
                        // wake it so the new interest set takes effect immediately
                        key.interestOps(key.interestOps() | SelectionKey.OP_READ);
                        selector.wakeup();
                    }
                    // NOTE: after a finished request OP_READ stays cleared, so the server does not
                    // read from this connection again.
                } catch (Exception e) {
                    e.printStackTrace();
                    if (client != null) {
                        clearClientState(client);
                    }
                    closeQuietly(channel);
                }
            }
        }

        /**
         * Resolves the request type (from cache or from the socket) and dispatches the request.
         *
         * @return {@code true} if the request was fully handled, {@code false} if more bytes are
         *         needed
         * @throws IOException on I/O errors, client disconnect, or an unknown request type
         */
        private boolean handleRequest(SocketChannel channel) throws Exception {
            String client = channel.getRemoteAddress().toString();
            System.out.println("接收到客户端的请求：" + client);

            InflightRequest pending = cachedRequestMap.get(client);
            boolean typeAlreadyKnown = pending != null && pending.getRequestType() != null;
            if (typeAlreadyKnown && SEND_FILE.equals(pending.getRequestType())) {
                System.out.println("上一次上传文件请求出现拆包问题，本次继续执行文件上传操作......");
            }

            // always dispatch on the (possibly cached) request type, so a split READ_FILE
            // request is never mistaken for an upload continuation
            Integer requestType = getRequestType(channel);
            if (requestType == null) {
                return false;
            }
            if (!typeAlreadyKnown) {
                System.out.println("从请求中解析出来请求类型：" + requestType);
            }

            if (SEND_FILE.equals(requestType)) {
                return handleSendFileRequest(channel);
            }
            if (READ_FILE.equals(requestType)) {
                return handleReadFileRequest(channel);
            }
            throw new IOException("Unknown request type " + requestType + " from " + client);
        }

        /**
         * Serves a download: responds with the 8-byte file length followed by the file content.
         *
         * @return {@code true} once the response was sent, {@code false} if the filename is
         *         still incomplete
         */
        private boolean handleReadFileRequest(SocketChannel channel) throws IOException {
            String client = channel.getRemoteAddress().toString();

            String filename = getFilename(channel);
            System.out.println("从网络请求中解析出来文件名：" + filename);
            if (filename == null) {
                return false;
            }

            File file = resolveLocalFile(filename);
            long fileLength = file.length();

            // response = 8-byte length header + whole file; toIntExact rejects files that do
            // not fit into a single heap buffer instead of silently truncating
            ByteBuffer buffer = ByteBuffer.allocate(Math.toIntExact(8L + fileLength));
            buffer.putLong(fileLength);

            long hasReadImageLength = 0;
            try (FileInputStream imageIn = new FileInputStream(file);
                 FileChannel imageChannel = imageIn.getChannel()) {
                // a single FileChannel.read may return fewer bytes than requested
                while (buffer.hasRemaining()) {
                    int n = imageChannel.read(buffer);
                    if (n < 0) {
                        break;
                    }
                    hasReadImageLength += n;
                }
            }
            System.out.println("从本次磁盘文件中读取了" + hasReadImageLength + " bytes的数据");

            if (hasReadImageLength != fileLength) {
                // never send a header that disagrees with the payload
                throw new IOException("File " + filename + " changed while being read: expected "
                        + fileLength + " bytes, read " + hasReadImageLength);
            }

            buffer.flip();
            int sent = writeFully(channel, buffer);
            System.out.println("将" + sent + " bytes的数据发送给了客户端.....");

            System.out.println("文件发送完毕，给客户端: " + client);
            cachedRequestMap.remove(client);
            return true;
        }

        /**
         * Returns the request type: the cached value if already parsed, otherwise reads (the rest
         * of) the 4-byte type from the channel.
         *
         * @return the request type, or {@code null} if not all 4 bytes have arrived yet
         */
        private Integer getRequestType(SocketChannel channel) throws IOException {
            String client = channel.getRemoteAddress().toString();
            InflightRequest request = getCachedRequest(client);
            if (request.getRequestType() != null) {
                return request.getRequestType();
            }

            // resume a partially read type field if there is one
            ByteBuffer requestTypeBuffer = requestTypeByClient.get(client);
            if (requestTypeBuffer == null) {
                requestTypeBuffer = ByteBuffer.allocate(4);
            }
            readFromClient(channel, requestTypeBuffer);

            if (requestTypeBuffer.hasRemaining()) {
                requestTypeByClient.put(client, requestTypeBuffer);
                return null;
            }
            requestTypeByClient.remove(client);
            requestTypeBuffer.flip();
            Integer requestType = requestTypeBuffer.getInt();
            request.setRequestType(requestType);
            return requestType;
        }

        /** Returns the number of upload bytes received so far for this connection (0 if none). */
        private long getHasReadFileLength(SocketChannel channel) throws IOException {
            String client = channel.getRemoteAddress().toString();
            Long hasRead = getCachedRequest(client).getHasReadedSize();
            return hasRead != null ? hasRead : 0L;
        }

        /**
         * Handles an upload: parses filename and length, buffers the content until complete,
         * then writes it to {@code DATA_DIR + filename}, replies {@code "SUCCESS"} and reports
         * the file to the NameNode.
         *
         * @return {@code true} once the upload is stored and acknowledged, {@code false} if
         *         more bytes are needed
         */
        private boolean handleSendFileRequest(SocketChannel channel) throws Exception {
            String client = channel.getRemoteAddress().toString();

            // 1. filename
            String filename = getFilename(channel);
            System.out.println("从网络请求中解析出来文件名：" + filename);
            if (filename == null) {
                return false;
            }
            // 2. file length
            Long fileLength = getFileLength(channel);
            System.out.println("从网络请求中解析出来文件大小：" + fileLength);
            if (fileLength == null) {
                return false;
            }
            // 3. bytes already received in earlier events
            long hasReadImageLength = getHasReadFileLength(channel);
            System.out.println("初始化已经读取的文件大小：" + hasReadImageLength);

            // 4. continue filling the in-memory file buffer
            ByteBuffer fileBuffer = fileByClient.get(client);
            if (fileBuffer == null) {
                fileBuffer = ByteBuffer.allocate(Math.toIntExact(fileLength));
            }
            hasReadImageLength += readFromClient(channel, fileBuffer);

            if (fileBuffer.hasRemaining()) {
                fileByClient.put(client, fileBuffer);
                getCachedRequest(client).setHasReadedSize(hasReadImageLength);
                System.out.println("本次文件上传出现拆包问题，缓存起来，下次继续读取.......");
                return false;
            }
            fileByClient.remove(client);
            fileBuffer.flip();

            // 5. persist the complete file
            File localFile = resolveLocalFile(filename);
            File parent = localFile.getParentFile();
            if (parent != null && !parent.isDirectory() && !parent.mkdirs() && !parent.isDirectory()) {
                throw new IOException("Cannot create directory " + parent);
            }
            long written = 0;
            try (FileOutputStream imageOut = new FileOutputStream(localFile);
                 FileChannel imageChannel = imageOut.getChannel()) {
                while (fileBuffer.hasRemaining()) {
                    written += imageChannel.write(fileBuffer);
                }
            }
            System.out.println("本次文件上传完毕，将" + written + " bytes的数据写入本地磁盘文件.......");

            // 6. acknowledge and report
            writeFully(channel, ByteBuffer.wrap("SUCCESS".getBytes()));
            cachedRequestMap.remove(client);
            System.out.println("文件读取完毕，返回响应给客户端: " + client);

            rpcClient.deltaReportDataNodeInfo(filename, hasReadImageLength);
            System.out.println("增量上报收到的文件副本给NameNode节点......");
            return true;
        }

        /**
         * Returns the upload's file length: the cached value if already parsed, otherwise reads
         * (the rest of) the 8-byte length from the channel.
         *
         * @return the file length, or {@code null} if not all 8 bytes have arrived yet
         */
        private Long getFileLength(SocketChannel channel) throws IOException {
            String client = channel.getRemoteAddress().toString();
            InflightRequest request = getCachedRequest(client);
            if (request.getFilesize() != null) {
                return request.getFilesize();
            }

            ByteBuffer fileLengthBuffer = fileLengthByClient.get(client);
            if (fileLengthBuffer == null) {
                fileLengthBuffer = ByteBuffer.allocate(8);
            }
            readFromClient(channel, fileLengthBuffer);

            if (fileLengthBuffer.hasRemaining()) {
                fileLengthByClient.put(client, fileLengthBuffer);
                return null;
            }
            fileLengthByClient.remove(client);
            fileLengthBuffer.flip();
            Long fileLength = fileLengthBuffer.getLong();
            request.setFilesize(fileLength);
            return fileLength;
        }

        /**
         * Returns the filename: the cached value if already parsed, otherwise reads (the rest of)
         * the 4-byte length and the name bytes. A completely parsed name is cached so that later
         * events of the same request do not try to parse it from the stream again.
         *
         * @return the filename, or {@code null} if it has not fully arrived yet
         */
        private String getFilename(SocketChannel channel) throws IOException {
            String client = channel.getRemoteAddress().toString();
            InflightRequest request = getCachedRequest(client);
            if (request.getFilename() != null) {
                return request.getFilename();
            }

            ByteBuffer filenameBuffer = filenameByClient.get(client);
            if (filenameBuffer == null) {
                // the name buffer does not exist yet, so the length prefix is still being read
                ByteBuffer lengthBuffer = filenameLengthByClient.get(client);
                if (lengthBuffer == null) {
                    lengthBuffer = ByteBuffer.allocate(4);
                }
                readFromClient(channel, lengthBuffer);
                if (lengthBuffer.hasRemaining()) {
                    filenameLengthByClient.put(client, lengthBuffer);
                    return null;
                }
                filenameLengthByClient.remove(client);
                lengthBuffer.flip();
                filenameBuffer = ByteBuffer.allocate(lengthBuffer.getInt());
            }

            readFromClient(channel, filenameBuffer);
            if (filenameBuffer.hasRemaining()) {
                filenameByClient.put(client, filenameBuffer);
                return null;
            }
            filenameByClient.remove(client);
            String filename = new String(filenameBuffer.array());
            request.setFilename(filename);
            return filename;
        }

        /**
         * Reads available bytes from the client into {@code buffer}.
         *
         * @return the number of bytes read (possibly 0)
         * @throws IOException if the client closed the connection (end-of-stream) mid-request
         */
        private int readFromClient(SocketChannel channel, ByteBuffer buffer) throws IOException {
            int n = channel.read(buffer);
            if (n < 0) {
                throw new IOException("Client " + channel.getRemoteAddress()
                        + " closed the connection before the request was complete");
            }
            return n;
        }

        /**
         * Writes all remaining bytes of {@code buffer}. A non-blocking write may accept only part
         * of the buffer, so this retries (yielding when nothing could be written) until done.
         *
         * @return the number of bytes written
         */
        private int writeFully(SocketChannel channel, ByteBuffer buffer) throws IOException {
            int total = 0;
            while (buffer.hasRemaining()) {
                int n = channel.write(buffer);
                if (n == 0) {
                    // socket send buffer is full; give other threads a chance before retrying
                    Thread.yield();
                }
                total += n;
            }
            return total;
        }

        /**
         * Maps a client-supplied filename to a file under {@code DATA_DIR}, rejecting names that
         * would escape the data directory (e.g. via {@code ..}).
         */
        private File resolveLocalFile(String filename) throws IOException {
            File baseDir = new File(DATA_DIR).getCanonicalFile();
            File target = new File(DATA_DIR + filename).getCanonicalFile();
            if (target.equals(baseDir) || !target.toPath().startsWith(baseDir.toPath())) {
                throw new IOException("Illegal file path outside data directory: " + filename);
            }
            return target;
        }

        /** Drops every piece of cached parsing state of the given connection. */
        private void clearClientState(String client) {
            cachedRequestMap.remove(client);
            requestTypeByClient.remove(client);
            filenameLengthByClient.remove(client);
            filenameByClient.remove(client);
            fileLengthByClient.remove(client);
            fileByClient.remove(client);
        }
    }

    /**
     * Already-parsed fields of a request that is still in progress on one connection.
     */
    class InflightRequest {
        // file name, starting with a path separator, e.g. /dir/enclosure/qq.jpg
        private String filename;
        // total file size announced by the client
        private Long filesize;
        // number of file bytes received so far
        private Long hasReadedSize;
        // request type (SEND_FILE / READ_FILE)
        private Integer requestType;

        public InflightRequest(String filename, Long imageSize, Long hasReadedSize) {
            this.filename = filename;
            this.filesize = imageSize;
            this.hasReadedSize = hasReadedSize;
        }

        public InflightRequest() {
        }

        @Override
        public String toString() {
            return "InflightRequest{" +
                    "filename='" + filename + '\'' +
                    ", filesize=" + filesize +
                    ", hasReadedSize=" + hasReadedSize +
                    ", requestType=" + requestType +
                    '}';
        }

        public String getFilename() {
            return filename;
        }

        public void setFilename(String filename) {
            this.filename = filename;
        }

        public Long getFilesize() {
            return filesize;
        }

        public void setFilesize(Long filesize) {
            this.filesize = filesize;
        }

        public Long getHasReadedSize() {
            return hasReadedSize;
        }

        public void setHasReadedSize(Long hasReadedSize) {
            this.hasReadedSize = hasReadedSize;
        }

        public Integer getRequestType() {
            return requestType;
        }

        public void setRequestType(Integer requestType) {
            this.requestType = requestType;
        }
    }

    /**
     * Returns the in-flight request state of {@code client}, creating (and caching) an empty one
     * if none exists yet.
     */
    InflightRequest getCachedRequest(String client) {
        return cachedRequestMap.computeIfAbsent(client, k -> new InflightRequest());
    }
}
